module SIRQ

using Flux,
    DiffEqFlux,
    DifferentialEquations,
    XLSX,
    DataFrames,
    DiffEqSensitivity,
    Plots, Optim, OrdinaryDiffEq

using Random

# Network evaluated inside the ODE right-hand side: 3 inputs -> 20 -> 20 -> 1 output
# (sigmoid on the last layer). Declared `const` because it is never reassigned in this
# file and is called on every right-hand-side evaluation; a non-const global would be
# looked up and dispatched dynamically each time.
const nn = FastChain(
    FastDense(3, 20, relu),
    FastDense(20, 20, relu),
    FastDense(20, 1, sigmoid),
)

# Seed the global RNG so the initial weights below are the same on every module load.
Random.seed!(1)

# Initial values appended after `params` in `run` to form the full parameter vector.
# NOTE(review): 521 presumably equals the parameter count of `nn`
# ((3*20 + 20) + (20*20 + 20) + (20*1 + 1)) — confirm against the FastChain layout.
p_init = rand(521) / 100

# Mutable module state, overwritten by `run` and read by the other functions.
ts = 0           # tspan[1]
tend = 0         # tspan[2]
trains = 0       # trainspan[1], first index of the fitting window
trainend = 0     # trainspan[2], last index of the fitting window
x = nothing      # data[1]
y = nothing      # data[2]
# No function in this file assigns this binding (`run` takes its own `u0` argument,
# which shadows it), so it stays `nothing`.
u0 = nothing

"""
    run(data, u0, params, tspan, trainspan; maxiters = 100, initial_stepnorm = 0.015)

Fit the model and return the result object produced by `DiffEqFlux.sciml_train`.

The full parameter vector is `[params; p_init]`. `SIRQ!` reads its first two entries as
the rate constants, so `params` is expected to hold exactly those two values.

# Arguments
- `data`: indexable collection; `data[1]` and `data[2]` are stored in the module globals
  `x` and `y` (the series plotted as "infected" and "recovered").
- `u0`: initial state passed to `ODEProblem`.
- `params`: leading entries of the parameter vector.
- `tspan`: `tspan[1]` and `tspan[2]` are the integration start and end.
- `trainspan`: `trainspan[1]:trainspan[2]` is the index window used by `loss_fun`.

# Keywords
- `maxiters`: iteration cap forwarded to `sciml_train`.
- `initial_stepnorm`: forwarded to the `BFGS` constructor.

# Side effects
Overwrites the module globals `x`, `y`, `ts`, `tend`, `trains`, `trainend` and `prob`.
"""
function run(data, u0, params, tspan, trainspan; maxiters = 100, initial_stepnorm = 0.015)
    p = [params; p_init]
    global x = data[1]
    global y = data[2]
    global ts = tspan[1]
    global tend = tspan[2]
    global trains = trainspan[1]
    global trainend = trainspan[2]
    # The problem is stored globally because `predict` (and therefore `loss_fun`) reads it.
    global prob = ODEProblem(SIRQ!, u0, (ts, tend), p)
    optimizer = BFGS(initial_stepnorm = initial_stepnorm)
    res = DiffEqFlux.sciml_train(loss_fun, p, optimizer, maxiters = maxiters)
    return res
end

"""
    plot_states()

Draw the stored series `x` and `y` as scatter points over `ts:1:tend`, then overlay
rows 2 and 3 of `predict(cur_p)` as lines. Returns the value of the final `plot!` call.

Reads the module globals `ts`, `tend`, `x`, `y` and `cur_p`.
"""
function plot_states()
    times = ts:1:tend

    # Observed series first, so the fitted curves are drawn on top of them.
    scatter(times, x, color = [1], label = "infected")
    scatter!(times, y, color = [2], label = "recovered")

    fitted = predict(cur_p)
    infected_fit = fitted[2, :]
    recovered_fit = fitted[3, :]

    # NOTE(review): `legend = false` is passed although every series carries a label —
    # confirm that hiding the legend is intended.
    plot!(
        times,
        infected_fit,
        color = [3],
        label = "predicted infected",
        legend = false,
    )
    return plot!(times, recovered_fit, color = [4], label = "predicted recovered")
end

"""
    get_pred()

Return `predict(cur_p)`, i.e. the model output for the parameter vector most recently
stored in the module global `cur_p`.
"""
get_pred() = predict(cur_p)


"""
    plot_quarantine()

Scatter-plot the network output returned by `get_quarantine()` against `ts:1:tend`,
with the y-axis fixed to `(0, 1.5)` and the legend disabled. Returns the plot.
"""
function plot_quarantine()
    t_step = ts:1:tend
    q_strength = get_quarantine()
    # The network returns a 1-row matrix; flatten it to a vector for plotting.
    return scatter(t_step, vec(q_strength), ylims = (0, 1.5), legend = false)
end

"""
    get_quarantine()

Evaluate the network `nn` on rows 2:4 of `predict(cur_p)` and return its output.

The network weights are `cur_p[3:end]`, the same slice `SIRQ!` uses, so the returned
values match the term that actually entered the ODE.
"""
function get_quarantine()
    prediction = predict(cur_p)
    return nn(prediction[2:4, :], cur_p[3:end])
end


"""
    SIRQ!(du, u, p, t)

In-place right-hand side of the four-state ODE. Writes the derivatives into `du`.

# Arguments
- `du`: output vector, overwritten with the derivatives of `(s, i, r, T)`.
- `u`: current state `(s, i, r, T)`. `T` accumulates the flow that the network term
  removes from `i` (presumably the quarantined compartment).
- `p`: parameter vector; `p[1]` is `β`, `p[2]` is `γ`, `p[3:end]` are the weights of `nn`.
- `t`: time (unused).
"""
function SIRQ!(du, u, p, t)
    s, i, r, T = u
    n = s + i + r + T
    β = p[1]
    γ = p[2]
    # Only the entries after β and γ belong to the network. Passing the whole `p` would
    # make the network reuse β and γ as its first two weights.
    nn_p = p[3:end]
    # Network input is (i, r, T); evaluate once and reuse in both equations.
    q = nn(u[2:4], nn_p)[1]
    du[1] = -β * s * i / n
    du[2] = β * s * i / n - (γ + q) * i
    du[3] = γ * i
    du[4] = q * i
    return nothing
end

"""
    predict(p)

Solve the module-level `prob` with `Tsit5()` using parameter vector `p`, saving at
`ts:1:tend`, and return the solution converted with `Array`. Callers in this file
index the result as `[state, time]`.
"""
function predict(p)
    save_times = ts:1:tend
    # NOTE(review): `u0` here is the module-level binding, which is still `nothing`
    # because `run` only receives `u0` as a local argument. Presumably `solve` then
    # falls back to the initial state stored in `prob` — confirm.
    sol = solve(prob, Tsit5(), u0 = u0, p = p, saveat = save_times)
    return Array(sol)
end

"""
    loss_fun(p)

Return the fitting loss for parameter vector `p`: the sum of squared differences between
the logarithms of the predicted and stored series, over the index window
`trains:trainend`, for both row 2 of the prediction versus `x` and row 3 versus `y`.

# Side effects
- Stores `p` in the module global `cur_p`, so the plotting/getter functions see the most
  recently evaluated parameters.
- Sets the module global `index` to 1 (not read anywhere in this file).
- Prints the loss value.
"""
function loss_fun(p)
    global cur_p = p
    global index = 1
    prediction = predict(p)
    # The `+` must end the first line. A line that already forms a complete expression
    # ends the statement, so a `+` at the start of the next line would be parsed as a
    # separate unary-plus expression and the second term would be silently dropped.
    se = sum((log.(prediction[2, trains:trainend]) - log.(x[trains:trainend])) .^ 2) +
         sum((log.(prediction[3, trains:trainend]) - log.(y[trains:trainend])) .^ 2)

    println(se)
    return se
end

end  # module
